[学习笔记]Splay平衡树
[学习笔记]Splay平衡树
这个代码拖得也是够久了。。一个月之前Treap写炸了之后就写了份Splay的代码,但总是毛病众多,结果直到今天才算是调出来。
算法介绍
伸展树(Splay Tree),也叫分裂树,是一种能自我平衡的二叉查找树,它能在均摊\(O(\log n)\)的时间内完成基于伸展树的插入、查找、修改和删除操作。Splay的格局,是和别处不同的(雾。Splay的目的并不是维持树的平衡,而是将上次访问的数放在最近的地方,方便下次访问。经过数学论证,可以得出其均摊时间复杂度仍为\(O(\log n)\)。
首先,因为Splay仍然是一棵平衡树,所以很显然会用常见的rotate()——旋转操作。翻转操作一般分为左旋转和右旋转,其作用是在满足二叉查找树的性质的前提下,将某个节点于其父节点进行“位置的交换”,具体效果分别如下:
旋转操作
因此,Splay树中最为核心的便是Splay()——伸展操作。其作用是将一个节点上移到根节点,方便下次进行访问。Splay()具体进行的操作需要根据情况而定,可分为以下三类:
假设我们操作的是node节点。此外,我们定义一个relation()函数,返回该节点是其父亲的左儿子还是右儿子。
- node节点的父亲节点就是根节点:
rotate(node);; 
node->relation() == node->fa->relation(): rotate(x->fa); rotate(x);; 
node->realtion() != node->fa->relation(): roate(x); rotate(x); 
这便是Splay中比较基础的两个操作,其他具体用到的函数一般都会调用到这两个函数。其它具体函数的原理会在下面讨论。
实现细节
代码参考: Menci
节点(Node)
这是Node结构体的定义,提供了所需的属性和基本的函数:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28
| struct Node { int value; int size, cnt;
Node *fa, *son[2];
bool relation() { return this == this->fa->son[R]; }
Node(const int &val = 0, Node *f = NULL) :value(val), size(1), cnt(1), fa(f) { son[L] = son[R] = NULL; }
~Node() { if (son[L]) delete son[L]; if (son[R]) delete son[R]; } };
|
旋转(rotate)
旋转操作的基本原理已于上处讲过,不再赘述。这里的旋转将左旋和右旋结合在了一起。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29
| void rotate(Node *x) { Node *old = x->fa; int flag = x->relation();
if (old->fa) old->fa->son[ old->relation() ] = x; x->fa = old->fa;
if (x->son[ flag ^ 1 ]) x->son[ flag ^ 1 ]->fa = old; old->son[ flag ] = x->son[ flag ^ 1 ];
old->fa = x; x->son[ flag ^ 1 ] = old;
update(old); update(x);
if (x->fa==NULL) root = x; return; }
|
伸展(splay)
Splay操作也如上文所说,不再赘述。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23
| Node *splay(Node *x, Node *target = NULL) { if (!x) return x;
while (x->fa != target) { if (x->fa->fa == target) rotate(x); else if (x->fa->relation() == x->relation() ) { rotate(x->fa); rotate(x); }else { rotate(x); rotate(x); } }
return x; }
|
更新(update)
update()函数负责对节点的size进行更新。实现也很简单:
1 2 3 4 5 6 7 8 9 10 11 12 13
| void update(Node *x) { if (!x) return;
x->size = x->cnt; if (x->son[L]) x->size += x->son[L]->size; if (x->son[R]) x->size += x->son[R]->size; return; }
|
查找(find)
查找也是一个基础操作,按照常规的二叉查找树的搜索方法即可:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21
| Node *find(const int &value) { Node *x = root;
while (x && x->value!=value) {
if (value < x->value) x = x->son[L]; else x = x->son[R]; }
if (!x) return NULL;
splay(x);
return x; }
|
插入(insert)
插入是平衡树中很基本的操作。在Splay中,我们在插入前首先需要考虑该数值在树内是否存在。如果存在,我们只需要让cnt++即可;但若是不存在,我们便需要先查找到适合插入的位置,然后新建节点插入,最后在伸展才可以。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44
| Node *insert(const int &value) { if (root==NULL) { root = new Node(value, NULL); return root; }
Node *x = find(value); if (x) { x->cnt ++; x->size ++;
return x; }
Node *target = root, *parent = NULL; bool mark;
while (target) { parent = target; parent->size ++;
if (value < parent->value) target = parent->son[mark = L]; else target = parent->son[mark = R]; }
target = new Node(value, parent); if (parent) parent->son[ mark ] = target;
splay(target);
return target; }
|
删除(erase)
删除操作和插入操作也类似,我们需要考虑删除的数是否只存在一个。若是存在多个,我们只需要让cnt--即可。但是若是存在一个,我们便需要通过一个较为复杂的方法对节点进行删除。
首先,很明显,我们不能直接对一个节点进行删除,这样的话可能会破坏树的结构。那么,我们的思路便是令我们想要删除的节点node不存在子节点,然后再删除。对此,我们有一个较为简单的方案:先将node的前趋节点移到根节点,再将node的后继节点移到根节点的右儿子的位置。此时,步骤如下,可以看出node节点变成了其后继的儿子,且其没有任何儿子,所以我们可以直接进行删除。
[](
代码如下:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30
| void erase(Node *x) { if (x->cnt > 1) { splay(x); x->cnt --; x->size --;
return; }
Node *pre = pred(x); Node *suc = succ(x);
splay(pre); splay(suc, pre);
delete x; if (x->fa) x->fa->son[ x->relation() ] = NULL; if (x==root) root = NULL;
update(suc); update(pre);
return; }
|
对数的删除操作只需要对上个函数进行调用就好了:
1 2 3 4 5 6 7
| void erase(const int &value) { erase(find(value));
return; }
|
前趋(pred)
因为我们每个节存储的是当前数值的所有数,而不是一个数,所以求前趋就非常简单了。在大多数情况下,一个数的前趋就是这个点的前趋。我们只需要求其左儿子,然后若存在右儿子,一直向下求即可:
1 2 3 4 5 6 7 8 9 10 11 12
| Node *pred(Node *x) { Node *pre = x->son[L];
if (!pre) return NULL; while (pre->son[R]) pre = pre->son[R]; return pre; }
|
当然,可能会存在一个节点没有左儿子的情况,这种情况下,一个节点的前趋一般为其父亲。不过,由于我们在这里一般调用的是pred(int):int,而在这个函数里(具体来说,是在其中调用的find()函数里),我们对该节点进行了一次splay(),所以其一定为根节点,所以不需要考虑这种情况。
在pred(int):int中,我们还需要多考虑一种情况:当想要查询的数值不存在时。这是我们需要手动插入一个节点,然后查询,最后再将这个节点删除。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
| const int &pred(const int &value) { Node *x = find(value);
if (x) return pred(x)->value; else { x = insert(value); const int &ans = pred(x)->value; erase(x);
return ans; } }
|
后继(succ)
后继和前趋类似,分为节点查询和数查询两个函数:
1 2 3 4 5 6 7 8 9 10 11 12 13
| Node *succ(Node *x) { Node *suc = x->son[R];
if (!suc) return NULL; while (suc->son[L]) suc = suc->son[L];
return suc; }
|
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17
| const int &succ(const int &value) {
Node *x = find(value);
if (x) return succ(x)->value; else { x = insert(value); const int &ans = succ(x)->value; erase(x);
return ans; } }
|
查询数的排名(rank)
查询排名也分为了两个函数,一个是查询节点排名,一个查询数的排名。对于节点,和前趋和后继一样,因为在rank(int)中伸展过,我们还是只考虑当前节点已被移到根节点的情况,较为简单,所以函数如下:
1 2 3 4 5
| int rank(Node *x) { return (x->son[L]==NULL) ? 0: x->son[L]->size; }
|
对于数,我们也还是需要考虑当前数不存在的情况:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
| int rank(const int &value) { Node *x = find(value);
if (x) return rank(x); else { x = insert(value); int ans = rank(x); erase(x);
return ans; } }
|
查询特定排名的节点或数(select)
select(int)函数主要依赖rank(int)来运作。这里需要注意,因为在查询过程中我们没有进行Splay(),所以说rank(Node*)代表的并不是该节点的排名,而是其左子树的大小。所以我们在向右子树搜索的时候需要将当前节点和左子树的大小减去,然后搜索。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20
| Node *select(int k) { Node *x = root; while ( !( rank(x)+1 <= k && (rank(x) + x->cnt >=k) ) ) { if (k < rank(x) + 1) x = x->son[L]; else { k -= rank(x) + x->cnt; x = x->son[R]; } }
splay(x);
return x; }
|
例题
Luogu 3369
题目来源: Luogu
非常裸的一道平衡树,涉及了很多基本的操作。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395
| #include <iostream>
#define L 0 #define R 1
using namespace std;
struct Node { int value; int size, cnt;
Node *fa, *son[2];
bool relation() { return this == this->fa->son[R]; }
Node(const int &val = 0, Node *f = NULL) :value(val), size(1), cnt(1), fa(f) { son[L] = son[R] = NULL; }
~Node() { if (son[L]) delete son[L]; if (son[R]) delete son[R]; } };
struct Splay { Node *root;
Splay() { root = NULL; }
~Splay() { delete root; }
void update(Node *x) { if (!x) return;
x->size = x->cnt; if (x->son[L]) x->size += x->son[L]->size; if (x->son[R]) x->size += x->son[R]->size;
return; }
void rotate(Node *x) { Node *old = x->fa; int flag = x->relation(); if (old->fa) old->fa->son[ old->relation() ] = x; x->fa = old->fa;
if (x->son[ flag ^ 1 ]) x->son[ flag ^ 1 ]->fa = old; old->son[ flag ] = x->son[ flag ^ 1 ];
old->fa = x; x->son[ flag ^ 1 ] = old;
update(old); update(x);
if (x->fa==NULL) root = x;
return; }
Node *splay(Node *x, Node *target = NULL) { if (!x) return x;
while (x->fa != target) { if (x->fa->fa == target) rotate(x); else if (x->fa->relation() == x->relation() ) { rotate(x->fa); rotate(x); }else { rotate(x); rotate(x); } }
return x; }
Node *pred(Node *x) { Node *pre = x->son[L];
if (!pre) return NULL;
while (pre->son[R]) pre = pre->son[R];
return pre; }
Node *succ(Node *x) { Node *suc = x->son[R];
if (!suc) return NULL; while (suc->son[L]) suc = suc->son[L];
return suc; }
int rank(Node *x) { return (x->son[L]==NULL) ? 0: x->son[L]->size; }
Node *find(const int &value) { Node *x = root;
while (x && x->value!=value) {
if (value < x->value) x = x->son[L]; else x = x->son[R]; }
if (!x) return NULL;
splay(x);
return x; }
Node *insert(const int &value) { if (root==NULL) { root = new Node(value, NULL); return root; }
Node *x = find(value); if (x) { x->cnt ++; x->size ++;
return x; }
Node *target = root, *parent = NULL; bool mark;
while (target) { parent = target; parent->size ++;
if (value < parent->value) target = parent->son[mark = L]; else target = parent->son[mark = R]; }
target = new Node(value, parent); if (parent) parent->son[ mark ] = target;
splay(target);
return target; }
void erase(Node *x) { if (x->cnt > 1) { splay(x); x->cnt --; x->size --;
return; }
Node *pre = pred(x); Node *suc = succ(x);
splay(pre); splay(suc, pre);
delete x; if (x->fa) x->fa->son[ x->relation() ] = NULL; if (x==root) root = NULL;
update(suc); update(pre);
return; }
void erase(const int &value) { erase(find(value));
return; }
int rank(const int &value) { Node *x = find(value);
if (x) return rank(x); else { x = insert(value); int ans = rank(x); erase(x);
return ans; } }
Node *select(int k) {
Node *x = root; while ( !( rank(x)+1 <= k && (rank(x) + x->cnt >=k) ) ) { if (k < rank(x) + 1) x = x->son[L]; else { k -= rank(x) + x->cnt; x = x->son[R]; } }
splay(x);
return x; }
const int &pred(const int &value) { Node *x = find(value);
if (x) return pred(x)->value; else { x = insert(value); const int &ans = pred(x)->value; erase(x);
return ans; } }
const int &succ(const int &value) {
Node *x = find(value);
if (x) return succ(x)->value; else { x = insert(value); const int &ans = succ(x)->value; erase(x);
return ans; } }
void print(Node *x) { if (x) { cout << "(" << x->value << ", " << x->cnt << ")-[ "; print(x->son[L]); cout << ", "; print(x->son[R]); cout << "]"; } }
}*splay;
int main() { splay = new Splay();
int n; cin >> n;
for (int i = 1; i<=n; i++) { int opt; cin >> opt;
int x; switch(opt) { case 1: cin >> x; splay->insert(x); break; case 2: cin >> x; splay->erase(x); break; case 3: cin >> x; cout << splay->rank(x)+1 << endl; break; case 4: cin >> x; cout << splay->select(x)->value << endl; break; case 5: cin >> x; cout << splay->pred(x) << endl; break; case 6: cin >> x; cout << splay->succ(x) << endl; break; default: cout << "Error: No such operation!" << endl; } } }
|
Luogu 3391
题目来源: Luogu